[InstCombine] Fold (trunc X) into X & Mask inside decomposeBitTestICmp#171195
[InstCombine] Fold (trunc X) into X & Mask inside decomposeBitTestICmp#171195
(trunc X) into X & Mask inside decomposeBitTestICmp#171195Conversation
|
@llvm/pr-subscribers-llvm-analysis @llvm/pr-subscribers-llvm-transforms Author: Tirthankar Mazumder (wermos) ChangesAddresses #170020. I'm not exactly sure what kind of Alive2 proof is required when the optimization has to do with I followed the suggestion given here: To do this, I had to make I'm also not sure if more tests are required or not. Full diff: https://github.com/llvm/llvm-project/pull/171195.diff 4 Files Affected:
diff --git a/llvm/include/llvm/Analysis/ValueTracking.h b/llvm/include/llvm/Analysis/ValueTracking.h
index b730a36488780..48cc85e719421 100644
--- a/llvm/include/llvm/Analysis/ValueTracking.h
+++ b/llvm/include/llvm/Analysis/ValueTracking.h
@@ -102,6 +102,15 @@ LLVM_ABI void computeKnownBitsFromContext(const Value *V, KnownBits &Known,
const SimplifyQuery &Q,
unsigned Depth = 0);
+/// Update \p Known with bits of \p V that are implied by \p Cmp.
+/// Comparisons involving `trunc V` are handled specially: known
+/// bits are computed for the truncated value and then extended to the bitwidth
+/// of \p V.
+LLVM_ABI void computeKnownBitsFromICmpCond(const Value *V, ICmpInst *Cmp,
+ KnownBits &Known,
+ const SimplifyQuery &SQ,
+ bool Invert);
+
/// Using KnownBits LHS/RHS produce the known bits for logic op (and/xor/or).
LLVM_ABI KnownBits analyzeKnownBitsFromAndXorOr(const Operator *I,
const KnownBits &KnownLHS,
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index 9cb6f19b9340c..5ab5f8cfccc7f 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -968,9 +968,9 @@ static void computeKnownBitsFromCmp(const Value *V, CmpInst::Predicate Pred,
}
}
-static void computeKnownBitsFromICmpCond(const Value *V, ICmpInst *Cmp,
- KnownBits &Known,
- const SimplifyQuery &SQ, bool Invert) {
+void llvm::computeKnownBitsFromICmpCond(const Value *V, ICmpInst *Cmp,
+ KnownBits &Known,
+ const SimplifyQuery &SQ, bool Invert) {
ICmpInst::Predicate Pred =
Invert ? Cmp->getInversePredicate() : Cmp->getPredicate();
Value *LHS = Cmp->getOperand(0);
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index ba5568b00441b..fa7c66d736c28 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -15,11 +15,13 @@
#include "llvm/Analysis/CmpInstAnalysis.h"
#include "llvm/Analysis/FloatingPointPredicateUtils.h"
#include "llvm/Analysis/InstructionSimplify.h"
+#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/ConstantRange.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/PatternMatch.h"
+#include "llvm/Support/KnownBits.h"
#include "llvm/Transforms/InstCombine/InstCombiner.h"
#include "llvm/Transforms/Utils/Local.h"
@@ -3376,9 +3378,13 @@ Value *InstCombinerImpl::foldAndOrOfICmps(ICmpInst *LHS, ICmpInst *RHS,
Value *LHS0 = LHS->getOperand(0), *RHS0 = RHS->getOperand(0);
Value *LHS1 = LHS->getOperand(1), *RHS1 = RHS->getOperand(1);
+ // dbgs() << "LHS0 = " << *LHS0 << "\nLHS1 = " << *LHS1 << '\n';
+ // dbgs() << "RHS0 = " << *RHS0 << "\nRHS1 = " << *RHS1 << '\n';
+
const APInt *LHSC = nullptr, *RHSC = nullptr;
match(LHS1, m_APInt(LHSC));
match(RHS1, m_APInt(RHSC));
+ // dbgs() << "LHSC = " << *LHSC << "\nRHSC = " << *RHSC << '\n';
// (icmp1 A, B) | (icmp2 A, B) --> (icmp3 A, B)
// (icmp1 A, B) & (icmp2 A, B) --> (icmp3 A, B)
@@ -3575,6 +3581,40 @@ Value *InstCombinerImpl::foldAndOrOfICmps(ICmpInst *LHS, ICmpInst *RHS,
return Builder.createIsFPClass(X, IsAnd ? FPClassTest::fcNormal
: ~FPClassTest::fcNormal);
+ if (!IsLogical && IsAnd) {
+ auto TryCandidate = [&](Value *X) -> Value * {
+ if (!X->getType()->isIntegerTy())
+ return nullptr;
+
+ Type *Ty = X->getType();
+ unsigned BitWidth = Ty->getScalarSizeInBits();
+
+ // KnownL and KnownR hold information deduced from the LHS icmp and RHS
+ // icmps, respectively
+ KnownBits KnownL(BitWidth), KnownR(BitWidth);
+
+ computeKnownBitsFromICmpCond(X, LHS, KnownL, Q, /*Invert=*/false);
+ computeKnownBitsFromICmpCond(X, RHS, KnownR, Q, /*Invert=*/false);
+
+ KnownBits Combined = KnownL.unionWith(KnownR);
+
+ // Avoid stomping on cases where one icmp alone determines X. Those are handled by more specific InstCombine folds.
+ if (KnownL.isConstant() || KnownR.isConstant())
+ return nullptr;
+
+ if (!Combined.isConstant())
+ return nullptr;
+
+ APInt ConstVal = Combined.getConstant();
+ return Builder.CreateICmpEQ(X, ConstantInt::get(Ty, ConstVal));
+ };
+
+ if (Value *Res = TryCandidate(LHS0))
+ return Res;
+ if (Value *Res = TryCandidate(RHS0))
+ return Res;
+ }
+
return foldAndOrOfICmpsUsingRanges(LHS, RHS, IsAnd);
}
diff --git a/llvm/test/Transforms/InstCombine/and-or-icmps.ll b/llvm/test/Transforms/InstCombine/and-or-icmps.ll
index 290e344acb980..9d69fadfa9627 100644
--- a/llvm/test/Transforms/InstCombine/and-or-icmps.ll
+++ b/llvm/test/Transforms/InstCombine/and-or-icmps.ll
@@ -702,9 +702,9 @@ define i1 @PR42691_10_logical(i32 %x) {
define i1 @substitute_constant_and_eq_eq(i8 %x, i8 %y) {
; CHECK-LABEL: @substitute_constant_and_eq_eq(
-; CHECK-NEXT: [[C1:%.*]] = icmp eq i8 [[X:%.*]], 42
; CHECK-NEXT: [[TMP1:%.*]] = icmp eq i8 [[Y:%.*]], 42
-; CHECK-NEXT: [[R:%.*]] = and i1 [[C1]], [[TMP1]]
+; CHECK-NEXT: [[TMP2:%.*]] = icmp eq i8 [[Y1:%.*]], 42
+; CHECK-NEXT: [[R:%.*]] = and i1 [[TMP1]], [[TMP2]]
; CHECK-NEXT: ret i1 [[R]]
;
%c1 = icmp eq i8 %x, 42
@@ -728,9 +728,9 @@ define i1 @substitute_constant_and_eq_eq_logical(i8 %x, i8 %y) {
define i1 @substitute_constant_and_eq_eq_commute(i8 %x, i8 %y) {
; CHECK-LABEL: @substitute_constant_and_eq_eq_commute(
-; CHECK-NEXT: [[C1:%.*]] = icmp eq i8 [[X:%.*]], 42
; CHECK-NEXT: [[TMP1:%.*]] = icmp eq i8 [[Y:%.*]], 42
-; CHECK-NEXT: [[R:%.*]] = and i1 [[C1]], [[TMP1]]
+; CHECK-NEXT: [[TMP2:%.*]] = icmp eq i8 [[Y1:%.*]], 42
+; CHECK-NEXT: [[R:%.*]] = and i1 [[TMP1]], [[TMP2]]
; CHECK-NEXT: ret i1 [[R]]
;
%c1 = icmp eq i8 %x, 42
@@ -741,9 +741,9 @@ define i1 @substitute_constant_and_eq_eq_commute(i8 %x, i8 %y) {
define i1 @substitute_constant_and_eq_eq_commute_logical(i8 %x, i8 %y) {
; CHECK-LABEL: @substitute_constant_and_eq_eq_commute_logical(
-; CHECK-NEXT: [[C1:%.*]] = icmp eq i8 [[X:%.*]], 42
; CHECK-NEXT: [[TMP1:%.*]] = icmp eq i8 [[Y:%.*]], 42
-; CHECK-NEXT: [[R:%.*]] = and i1 [[C1]], [[TMP1]]
+; CHECK-NEXT: [[TMP2:%.*]] = icmp eq i8 [[Y1:%.*]], 42
+; CHECK-NEXT: [[R:%.*]] = and i1 [[TMP1]], [[TMP2]]
; CHECK-NEXT: ret i1 [[R]]
;
%c1 = icmp eq i8 %x, 42
@@ -1392,12 +1392,12 @@ define i1 @bitwise_and_bitwise_and_icmps(i8 %x, i8 %y, i8 %z) {
define i1 @bitwise_and_bitwise_and_icmps_comm1(i8 %x, i8 %y, i8 %z) {
; CHECK-LABEL: @bitwise_and_bitwise_and_icmps_comm1(
-; CHECK-NEXT: [[C1:%.*]] = icmp eq i8 [[Y:%.*]], 42
+; CHECK-NEXT: [[TMP3:%.*]] = icmp eq i8 [[Y:%.*]], 42
; CHECK-NEXT: [[Z_SHIFT:%.*]] = shl nuw i8 1, [[Z:%.*]]
; CHECK-NEXT: [[TMP1:%.*]] = or i8 [[Z_SHIFT]], 1
; CHECK-NEXT: [[TMP2:%.*]] = and i8 [[X:%.*]], [[TMP1]]
-; CHECK-NEXT: [[TMP3:%.*]] = icmp eq i8 [[TMP2]], [[TMP1]]
-; CHECK-NEXT: [[AND2:%.*]] = and i1 [[C1]], [[TMP3]]
+; CHECK-NEXT: [[TMP4:%.*]] = icmp eq i8 [[TMP2]], [[TMP1]]
+; CHECK-NEXT: [[AND2:%.*]] = and i1 [[TMP3]], [[TMP4]]
; CHECK-NEXT: ret i1 [[AND2]]
;
%c1 = icmp eq i8 %y, 42
@@ -3721,3 +3721,28 @@ define i1 @merge_range_check_or(i8 %a) {
%and = or i1 %cmp1, %cmp2
ret i1 %and
}
+
+; Just a very complicated way of checking if v1 == 0.
+define i1 @complicated_zero_equality_test(i64 %v1) {
+; CHECK-LABEL: @complicated_zero_equality_test(
+; CHECK-NEXT: [[V5:%.*]] = icmp eq i64 [[V1:%.*]], 0
+; CHECK-NEXT: ret i1 [[V5]]
+;
+ %v2 = trunc i64 %v1 to i32
+ %v3 = icmp eq i32 %v2, 0
+ %v4 = icmp ult i64 %v1, 4294967296 ; 2 ^ 32
+ %v5 = and i1 %v4, %v3
+ ret i1 %v5
+}
+
+define i1 @commuted_complicated_zero_equality_test(i64 %v1) {
+; CHECK-LABEL: @commuted_complicated_zero_equality_test(
+; CHECK-NEXT: [[V5:%.*]] = icmp eq i64 [[V1:%.*]], 0
+; CHECK-NEXT: ret i1 [[V5]]
+;
+ %v2 = trunc i64 %v1 to i32
+ %v3 = icmp ult i64 %v1, 4294967296 ; 2 ^ 32
+ %v4 = icmp eq i32 %v2, 0
+ %v5 = and i1 %v4, %v3
+ ret i1 %v5
+}
|
|
|
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
|
I've addressed the email thing as well. |
|
Ping @dtcxzyw for review. |
|
if the trunc is repalced by an and this is already folded see https://alive2.llvm.org/ce/z/Whfa65 llvm-project/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp Lines 204 to 212 in f803e46 |
You could use this format. |
Oh yes, we can simply handle this in |
|
Alright, I'll work on moving the patch you shared into |
|
I've redone the entire implementation. I removed my previous changes and modified |
(x < 2^32) & (trunc(x to i32) == 0) into x == 0(trunc X) into X & Mask inside decomposeBitTestICmp
🐧 Linux x64 Test Results
✅ The build succeeded and all tests passed. |
🪟 Windows x64 Test Results
✅ The build succeeded and all tests passed. |
|
It looks like my changes affected a lot of tests. What is the protocol for determining whether these changes are regressions or not? |
Regenerate the failing tests (you can use something like |
|
I manually checked the test diffs by feeding the changes into Alive2 and seeing if the transformation is valid or not:
|
|
The test diffs I just pushed are all of the form I'm attaching an Alive2 link for this transform for the |
| ; CHECK-NEXT: [[ORIENTATIONS:%.*]] = alloca [1 x [1 x %struct.x]], align 8 | ||
| ; CHECK-NEXT: [[ORIENTATIONS:%.*]] = alloca [1 x [1 x [[STRUCT_X:%.*]]]], align 8 |
There was a problem hiding this comment.
I don't know why this line changed. Is this a harmless change? I'm not 100% sure, but it looks like a cosmetic (variable renaming) change to me.
There was a problem hiding this comment.
as this is the only change now this file can be reverted
| ; CHECK-NEXT: [[ORIENTATIONS:%.*]] = alloca [1 x [1 x %struct.x]], align 8 | ||
| ; CHECK-NEXT: [[ORIENTATIONS:%.*]] = alloca [1 x [1 x [[STRUCT_X:%.*]]]], align 8 |
There was a problem hiding this comment.
as this is the only change now this file can be reverted
|
I've reverted the change in I'm going to close and reopen the PR to restart the CI because the AArch64 build failed due to a server timeout. |
andjo403
left a comment
There was a problem hiding this comment.
looks good to me wait for second reviewer.
dtcxzyw
left a comment
There was a problem hiding this comment.
LGTM. However, as I commented in #171195 (comment), the current implementation converts (trunc X) ==/!= C to (and X, mask) ==/!= 0 unconditionally and causes some irrevertible regressions. As a follow-up, we can try to preserve previous behavior by checking if the mask is getLowBitsSet(bit width of original operands). Then we can check the net effect.
|
Ping @nikic for final review |
Do you need me to merge this? |
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/3/builds/27400 Here is the relevant piece of the build log for the reference |
…stICmp` (llvm#171195) Resolves llvm#170020. Added another case to the `ICmp::EQ`/`ICmp::NE` case in the switch inside `decomposeBitTestICmp` to convert `trunc X` into a `X & Mask`.
|
Hi @dtcxzyw, we see a surprising interaction of this optimization and memory sanitizer, which starts complaining about a use of uninitialized value in the code looking roughly like this (https://gcc.godbolt.org/z/M1MKcznhx): bool f(std::optional<int> x) {
return x == 0 || x == 1;
}IIUC, after SROA And the problem seems to be that the comparisons take into account potentially uninitialized lower 32 bits. Is this a valid thing to do in LLVM IR? |
To my knowledge, we don't model partial undef in LLVM. Can you find which transform breaks the check? If it involves two uses of the same value, we can simply add an isGuaranteedNotToBeUndef guard. |
@thurstond should know better what happens from the point of view of msan. |
Resolves #170020.
Added another case to the
ICmp::EQ/ICmp::NEcase in the switch insidedecomposeBitTestICmpto converttrunc Xinto aX & Mask.